#!/usr/bin/python
'''
An example of storing traces to a file using root_trees.py
'''

import numpy as np
import time
import sys
import ROOT
from grand.dataio.root_trees import *

# Output file: use the single command-line argument when it names a ".root"
# file, otherwise fall back to the default name.
filename = "stored_data.root"
if len(sys.argv) == 2 and sys.argv[1].endswith(".root"):
    filename = sys.argv[1]

# Generate dummy data for `event_count` events, mimicking real data: each event
# has a random number of traces (3 to 6) and each trace a random length
# (900 to 999 samples). A trace consists of 4 channels of int16 ADC counts in
# [-20, 20]; `traces` holds the matching float32 voltages (0.9 V per 8192 counts).
event_count = 10
adc_traces = []
traces = []
trace_counts = []
for ev in range(event_count):
    n_traces = np.random.randint(3, 7)
    print(f"{n_traces} traces for event {ev}")
    trace_counts.append(n_traces)

    event_adc = []
    event_voltage = []
    for _ in range(n_traces):
        length = np.random.randint(900, 1000)
        # 4 channels per trace, all of the same length
        adc = tuple(
            np.random.randint(-20, 21, length).astype(np.int16) for _ in range(4)
        )
        event_adc.append(adc)
        event_voltage.append(
            tuple((channel * 0.9 / 8192).astype(np.float32) for channel in adc)
        )

    adc_traces.append(event_adc)
    traces.append(event_voltage)

# ********** Generate Run Tree ****************
# The Run tree is written first, so that the Event trees written afterwards can find it.
# It is filled right away with dummy values; only the DU list depends on the generated traces.
trun = TRun()
trun.run_mode = 1
trun.data_source = "dummy"
trun.comment = "Generated by data_storing.py"
trun.run_number = 0
trun.site = "dummy site"
# Event numbers used below run from 0 to event_count - 1 (inclusive)
trun.first_event = 0
trun.last_event = event_count - 1
trun.t_bin_size = [0.5]

# Generate dummy info for each detector unit: as many DUs as the largest event contains
for du in range(np.max(trace_counts)):
    trun.du_id.append(du)
    trun.du_geoid.append([1, 2, 3])
    trun.du_xyz.append([4, 5, 6])
    trun.du_type.append("antenna 1.0")
    trun.du_tilt.append([4, 5, 6])
    trun.du_ground_tilt.append([4, 5, 6])
    trun.du_nut.append(du + 10)
    trun.du_feb.append(du + 100)


trun.fill()
trun.write(filename)
print("Wrote trun")

# ********** ADC Counts ****************
# Create the ADC counts tree
tadccounts = TADC()
tadccounts.comment = "Generated data_storing.py"

# Fill the tree with the generated events
for ev in range(event_count):
    tadccounts.run_number = 0
    tadccounts.event_number = ev
    # First data unit in the event
    tadccounts.first_du = 0
    # Use the current Unix time as the event time. time.time() is already
    # epoch-based; time.mktime(time.gmtime()) would shift it by the local UTC offset.
    tadccounts.time_seconds = int(time.time())
    # Event nanoseconds 0 for now
    tadccounts.time_nanoseconds = 0
    # Triggered event
    # tadccounts.event_type = 0x8000
    # The number of antennas in the event
    tadccounts.du_count = len(adc_traces[ev])

    # Per-DU values, one entry per trace of this event
    du_id = []
    du_seconds = []
    du_nanoseconds = []
    trigger_position = []
    trigger_flag = []
    atm_temperature = []
    atm_pressure = []
    atm_humidity = []
    acceleration_x = []
    acceleration_y = []
    acceleration_z = []
    trace_ch = []
    for i, trace in enumerate(adc_traces[ev]):
        # Dumb values just for filling
        du_id.append(i)
        du_seconds.append(tadccounts.time_seconds)
        du_nanoseconds.append(tadccounts.time_nanoseconds)
        trigger_position.append(i // 2)
        trigger_flag.append(0x8000)
        # Integer dummy values, as the ADC tree stores raw counts
        atm_temperature.append(20 + ev // 2)
        atm_pressure.append(1024 + ev // 2)
        atm_humidity.append(50 + ev // 2)
        acceleration_x.append(ev // 2)
        acceleration_y.append(ev // 3)
        acceleration_z.append(ev // 4)

        # List of the time traces for channels 0/1/2/3 of this DU, with dummy
        # constant offsets added per channel
        trace_ch.append([trace[0] + 1, trace[1] + 1, trace[2] + 1, trace[3] + 4])

    tadccounts.du_id = du_id
    tadccounts.du_seconds = du_seconds
    tadccounts.du_nanoseconds = du_nanoseconds
    tadccounts.trigger_position = trigger_position
    tadccounts.trigger_flag = trigger_flag
    tadccounts.atm_temperature = atm_temperature
    tadccounts.atm_pressure = atm_pressure
    tadccounts.atm_humidity = atm_humidity
    tadccounts.acceleration_x = acceleration_x
    tadccounts.acceleration_y = acceleration_y
    tadccounts.acceleration_z = acceleration_z
    tadccounts.trace_ch = trace_ch

    tadccounts.fill()

# Write the tree to the storage
tadccounts.write(filename)
print("Wrote tadccounts")

# ********** Raw Voltage ****************

# Voltage has the same data as ADC counts tree, but recalculated to "real" (usually float) values

# Dummy ADC-to-voltage conversion: 0.9 V is equal to 8192 counts for XiHu data.
# `traces` was already built from `adc_traces` with this same factor, so it is used directly.
adc2v = 0.9 / 8192

# Create the Raw Voltage tree
trawvoltage = TRawVoltage()
trawvoltage.comment = "Generated data_storing.py"

# Fill the tree with the generated events
for ev in range(event_count):
    trawvoltage.run_number = 0
    trawvoltage.event_number = ev
    # First data unit in the event
    trawvoltage.first_du = 0
    # Use the current Unix time as the event time. time.time() is already
    # epoch-based; time.mktime(time.gmtime()) would shift it by the local UTC offset.
    trawvoltage.time_seconds = int(time.time())
    # Event nanoseconds 0 for now
    trawvoltage.time_nanoseconds = 0
    # The number of antennas in the event
    trawvoltage.du_count = len(traces[ev])

    # Per-DU values, one entry per trace of this event
    du_id = []
    du_seconds = []
    du_nanoseconds = []
    trigger_flag = []
    atm_temperature = []
    atm_pressure = []
    atm_humidity = []
    acceleration = []
    trace_ch = []
    for i, trace in enumerate(traces[ev]):
        # Dumb values just for filling
        du_id.append(i)
        du_seconds.append(trawvoltage.time_seconds)
        du_nanoseconds.append(trawvoltage.time_nanoseconds)
        trigger_flag.append(0x8000)
        atm_temperature.append(20 + ev / 2)
        atm_pressure.append(1024 + ev / 2)
        atm_humidity.append(50 + ev / 2)
        # One (x, y, z) acceleration triplet per DU
        acceleration.append([ev / 2, ev / 3, ev / 4])

        # All 4 channel traces of this DU
        trace_ch.append([trace[0], trace[1], trace[2], trace[3]])

    trawvoltage.du_id = du_id
    trawvoltage.du_seconds = du_seconds
    trawvoltage.du_nanoseconds = du_nanoseconds
    # NOTE(review): trigger_position is not set for this tree - confirm whether TRawVoltage stores it
    trawvoltage.trigger_flag = trigger_flag
    trawvoltage.atm_temperature = atm_temperature
    trawvoltage.atm_pressure = atm_pressure
    trawvoltage.atm_humidity = atm_humidity
    # ToDo: check if this is stored correctly
    trawvoltage.du_acceleration = acceleration
    trawvoltage.trace_ch = trace_ch

    trawvoltage.fill()

# Write the tree to the storage
trawvoltage.write(filename)
print("Wrote trawvoltage")

# ********** Voltage ****************

# For now basically takes the raw voltage values and keeps channels 0, 1, 2 of each trace as x, y, z

# Create the Voltage tree
tvoltage = TVoltage()
tvoltage.comment = "Generated data_storing.py"

# Fill the tree with the generated events
for ev in range(event_count):
    tvoltage.run_number = 0
    tvoltage.event_number = ev
    # First data unit in the event
    tvoltage.first_du = 0
    # Use the current Unix time as the event time. time.time() is already
    # epoch-based; time.mktime(time.gmtime()) would shift it by the local UTC offset.
    tvoltage.time_seconds = int(time.time())
    # Event nanoseconds 0 for now
    tvoltage.time_nanoseconds = 0
    # The number of antennas in the event
    tvoltage.du_count = len(traces[ev])

    # Per-DU values, one entry per trace of this event
    du_id = []
    du_seconds = []
    du_nanoseconds = []
    trigger_flag = []
    acceleration = []
    trace_xyz = []
    for i, trace in enumerate(traces[ev]):
        # Dumb values just for filling
        du_id.append(i)
        du_seconds.append(tvoltage.time_seconds)
        du_nanoseconds.append(tvoltage.time_nanoseconds)
        # Same dummy trigger flag as the other trees (previously never filled)
        trigger_flag.append(0x8000)
        # One (x, y, z) acceleration triplet per DU, same layout as the raw voltage tree
        acceleration.append([ev / 2, ev / 3, ev / 4])

        trace_xyz.append([trace[0], trace[1], trace[2]])

    tvoltage.du_id = du_id
    tvoltage.du_seconds = du_seconds
    tvoltage.du_nanoseconds = du_nanoseconds
    tvoltage.trigger_flag = trigger_flag
    tvoltage.du_acceleration = acceleration
    tvoltage.trace = trace_xyz

    tvoltage.fill()

# Write the tree to the storage
tvoltage.write(filename)
print("Wrote tvoltage")

# ********** Efield ****************

# Efield has some of the Voltage tree data + FFTs of the Efield traces
from scipy import fftpack

# Recalculate Voltage to Efield - just an example, so just multiply by a dumb value
# Here the GRANDlib Efield computation function with antenna model should be used
v2ef = 1.17

# Create the Efield tree
tefield = TEfield()
tefield.comment = "Generated data_storing.py"

# Fill the tree with every second of generated events - dumb selection
for ev in range(0, event_count, 2):
    tefield.run_number = 0
    tefield.event_number = ev
    # Unix time corresponding to the GPS seconds of the trigger; here simply the
    # current time (time.time() is epoch-based, unlike mktime(gmtime()))
    tefield.time_seconds = int(time.time())
    # GPS nanoseconds corresponding to the trigger of the first triggered station
    # Event nanoseconds 0 for now
    tefield.time_nanoseconds = 0
    # Triggered event
    tefield.event_type = 0x8000
    # The number of antennas in the event
    tefield.du_count = len(traces[ev])

    # Per-DU values, one entry per trace of this event
    du_id = []
    du_seconds = []
    du_nanoseconds = []
    trace_xyz = []
    fft_mag_xyz = []
    fft_phase_xyz = []
    for i, trace in enumerate(traces[ev]):
        # Dumb values just for filling
        du_id.append(i)
        du_seconds.append(tefield.time_seconds)
        du_nanoseconds.append(tefield.time_nanoseconds)

        # Dummy Voltage -> Efield conversion of channels 0, 1, 2 (x, y, z).
        # A real ComputeEfield() with an antenna model should be called here instead.
        # ToDo: better read the Voltage trace from the TTree
        efield_xyz = [(np.asarray(ch) * v2ef).astype(np.float32) for ch in trace[:3]]
        trace_xyz.append([component.tolist() for component in efield_xyz])

        # Spectrum of each Efield component separately (x, y, z)
        spectra = [fftpack.fft(component) for component in efield_xyz]
        fft_mag_xyz.append([np.abs(spectrum) for spectrum in spectra])
        # Phase of each spectrum bin, in radians
        fft_phase_xyz.append([np.angle(spectrum) for spectrum in spectra])

    tefield.du_id = du_id
    tefield.du_seconds = du_seconds
    tefield.du_nanoseconds = du_nanoseconds
    tefield.trace = trace_xyz
    tefield.fft_mag = fft_mag_xyz
    tefield.fft_phase = fft_phase_xyz

    tefield.fill()

# write the tree to the storage, but don't close the file - it will be used for tshower
# ToDo: is this correct? Not sure if I should use the file opened for writing when I am reading
tefield.write(filename, close_file=False)
print("Wrote tefield")

# Generation of shower data for each event - this should be reconstruction, but here just dumb values
tshower = TShower()
tshower.comment = "Generated data_storing.py"
# Loop through all Efield entries
for i in range(tefield.get_entries()):
    # Get the Efield event
    tefield.get_entry(i)

    tshower.run_number = tefield.run_number
    tshower.event_number = tefield.event_number

    tshower.primary_type = "particle"
    tshower.energy_em = np.random.random(1) * 1e8
    tshower.energy_primary = tshower.energy_em * 1.2
    tshower.azimuth = np.random.random(1) * 360
    tshower.zenith = np.random.random(1) * 180 - 90
    tshower.shower_core_pos = np.random.random(3)
    tshower.atmos_model = "dense air dummy"
    tshower.atmos_model_param = np.random.random(3)
    tshower.magnetic_field = np.random.random(3)
    tshower.core_alt = 3000.0 + np.random.randint(0, 1000)
    tshower.xmax_grams = np.random.random(1) * 500
    tshower.xmax_pos = np.random.random(3)
    tshower.xmax_pos_shc = tshower.xmax_pos * 0.5
    # Current time split into whole seconds and the nanoseconds within that second.
    # Integer nanoseconds avoid float rounding (and the previously un-imported datetime).
    core_time_s, core_time_ns = divmod(time.time_ns(), 1_000_000_000)
    tshower.core_time_s = core_time_s
    tshower.core_time_ns = core_time_ns

    tshower.fill()

tshower.write(filename)
print("Wrote tshower")

# Need to manually close file if the script is executed in Jupyter
tshower.close_file()

print(f"Finished writing file {filename}")
